
[bugfix] fix deepseek rope sincoscache re-generation #2744

Merged
MengqingCao merged 1 commit into vllm-project:main from zzzzwwjj:main
Sep 8, 2025

Conversation

@zzzzwwjj (Collaborator) commented Sep 4, 2025

What this PR does / why we need it?

The current implementation regenerates the sin_cos_cache in RoPE whenever kv_seqlen > 4k, because the cache is initialized with a length of only 4k.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

After this PR is merged, the sin_cos_cache will no longer grow in the forward function, so test_native_rope_deepseek_forward_cache_handling is no longer necessary.
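To make the cache-sizing change concrete, here is a minimal pure-Python sketch of the idea (class and method names are illustrative, not the actual vllm-ascend code): the cache length is fixed once in __init__ using the scaling factor, so the forward path never needs to rebuild the sin/cos tables.

```python
import math


class RopeCacheSketch:
    """Illustrative only: models the cache-sizing logic, not real RoPE math."""

    def __init__(self, max_position_embeddings=4096, scaling_factor=8.0):
        # Pre-allocate for the scaled maximum (the fix), instead of
        # only max_position_embeddings (the old behavior, which forced
        # a rebuild whenever kv_seqlen exceeded 4k).
        self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)

    def needs_regeneration(self, kv_seqlen):
        # With the pre-allocated cache, any kv_seqlen up to the scaled
        # maximum fits without regenerating the cache in forward().
        return kv_seqlen > self.max_seq_len
```

With the old 4k initialization, a kv_seqlen of 8192 would have triggered a rebuild; with the scaled pre-allocation it fits in the cache from the start.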

@gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request correctly addresses a bug in the DeepSeek RoPE implementation where the sincoscache was being regenerated unnecessarily for sequences longer than the initial max_position_embeddings. The fix involves pre-allocating a larger cache during initialization by incorporating the scaling_factor, and removing the dynamic resizing logic from the forward pass. These changes are consistently applied across both the standard and torchair implementations, and the corresponding obsolete tests are correctly removed. My review identifies a minor but important point for robustness: the calculation of max_seq_len results in a float, which could lead to implicit behavior dependencies. I've suggested using math.ceil to ensure it's an integer, which improves clarity and aligns with the expectations of parent classes.

Comment thread: vllm_ascend/ops/rotary_embedding.py (Outdated)

Contributor comment (severity: high):

The scaling_factor is a float, which results in self.max_seq_len being a float. While torch.arange can handle a float end value, it's safer and clearer to use an integer for a value representing a sequence length. This avoids reliance on the implicit behavior of torch.arange with floats and improves code robustness. The parent class RotaryEmbedding also expects an int for the sequence length in its _set_cos_sin_cache method signature. Using math.ceil will ensure the allocated cache is large enough and the length is an integer.

Suggested change:
- self.max_seq_len = max_position_embeddings * scaling_factor
+ self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
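To illustrate the reviewer's point (the values below are illustrative, not taken from any model config): multiplying an int by a float scaling factor yields a float, whereas math.ceil returns an int and rounds up, so the allocated cache is never undersized.

```python
import math

max_position_embeddings = 4096
scaling_factor = 2.3  # illustrative non-integral factor

as_float = max_position_embeddings * scaling_factor           # float (not an int)
as_int = math.ceil(max_position_embeddings * scaling_factor)  # int, rounded up

# as_float cannot safely be used as a sequence length; as_int can,
# and rounding up guarantees the cache covers the scaled maximum.
```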

Contributor comment (severity: high):

The scaling_factor is a float, making self.max_seq_len a float. It is better practice to use an integer for sequence lengths to avoid relying on the implicit behavior of torch.arange with float inputs and to make the code more explicit and robust. The corresponding method in the parent RotaryEmbedding class also expects an integer. Using math.ceil ensures the length is an integer and the cache size is sufficient.

Suggested change:
- self.max_seq_len = max_position_embeddings * scaling_factor
+ self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)

@github-actions Bot commented Sep 4, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@codecov Bot commented Sep 4, 2025

Codecov Report

❌ Patch coverage is 91.66667% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.78%. Comparing base (24d4dad) to head (45dfe83).
⚠️ Report is 728 commits behind head on main.

Files with missing lines                                 Patch %   Lines
...m_ascend/torchair/ops/torchair_rotary_embedding.py    50.00%    2 Missing ⚠️
.../ut/torchair/ops/test_torchair_rotary_embedding.py    93.75%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2744      +/-   ##
==========================================
- Coverage   73.71%   72.78%   -0.94%     
==========================================
  Files         152      154       +2     
  Lines       21967    21313     -654     
==========================================
- Hits        16194    15513     -681     
- Misses       5773     5800      +27     
Flag        Coverage Δ
unittests   72.78% <91.66%> (-0.94%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@wangxiyuan (Collaborator) commented:

This is a cherry-pick from #1551.

@zzzzwwjj force-pushed the main branch 4 times, most recently from f1cdcde to 6b4f8f0 on September 8, 2025 03:53
Signed-off-by: zzzzwwjj <1183291235@qq.com>
@MengqingCao MengqingCao merged commit 4df8df5 into vllm-project:main Sep 8, 2025
31 of 32 checks passed
1Fire4 pushed a commit to 1Fire4/vllm-ascend that referenced this pull request Sep 9, 2025
### What this PR does / why we need it?
The current implementation will result in duplicate generation of
`sin_cos_cache` in rope when `kv_seqlen` > 4k, because the
initialization length of the `sin_cos_cache` is only 4k.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
After this PR merged, sin_cos_cache will not increase in forward func,
so `test_native_rope_deepseek_forward_cache_handling` is not necessary.

- vLLM version: v0.10.1.1
- vLLM main:
vllm-project/vllm@60f0843

Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: 1Fire4 <wangdingyi2@huawei.com>
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Sep 10, 2025
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
NSDie pushed a commit to NSDie/vllm-ascend that referenced this pull request Nov 24, 2025
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025


4 participants